package https

import (
	"bytes"
	"compress/gzip"
	"io"
	"math/rand"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strconv"
	"strings"
	"time"
)

// 传入的接口信息，用于处理响应的回调操作
type GatwayInterface interface {
	// 超时时间
	Timeout() time.Duration

	// 长连接超时时间
	KeepAlive() time.Duration

	// TLS握手超时时间
	TLSHandshakeTimeout() time.Duration

	// 负载均衡的URL列表[此处采用随机的方式进行请求访问]
	Urls() []string

	// 请求的网址信息[可做额外处理，如追加header参数等](此处不建议重写URL)
	// 追加Header方法：request.Header.Set("","")
	QuestUrl(request *http.Request)

	// 响应处理[可做额外处理，如重写返回信息等](socket长连接不支持此方法)
	// 重写方法：将重写结果作为[]byte进行返回(若此值为nul则表示不进行重写)
	ResponseUrl(response *http.Response) ([]byte, error)

	// 错误处理方法
	// 错误回调 ：关闭real_server时测试，错误回调
	// 范围：transport.RoundTrip发生的错误、以及ModifyResponse发生的错误
	Error(w http.ResponseWriter, r *http.Request, err error)
}

// 网址转发操作【此操作暂不支持通配网址信息】
//
//	h	原始请求处理器
//	intef	网址转发配置
func Gatway(h http.Handler, intef GatwayInterface) http.Handler {
	targets := parseUrls(intef.Urls())

	// 转发配置
	transport := &http.Transport{
		DialContext: (&net.Dialer{
			Timeout:   intef.Timeout(),   //连接超时
			KeepAlive: intef.KeepAlive(), //长连接超时时间
		}).DialContext,
		TLSHandshakeTimeout: intef.TLSHandshakeTimeout(), //tls握手超时时间
	}

	// 请求协调者
	director := func(req *http.Request) {
		if len(targets) == 0 {
			return
		}
		target := targets[rand.Intn(len(targets))]
		req.URL.Scheme = target.Scheme
		req.URL.Host = target.Host
		req.URL.Path = target.Path
		req.Host = target.Host
		req.URL.RawQuery = mergeQuery(target.RawQuery, req.URL.RawQuery)
		if _, ok := req.Header["User-Agent"]; !ok {
			req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/99.0.4844.84 Safari/537.36 HBPC/12.1.3.303")
		}
		intef.QuestUrl(req)
	}

	// 更改内容
	modifyFunc := func(resp *http.Response) error {
		if strings.Contains(resp.Header.Get("Connection"), "Upgrade") {
			return nil
		}
		payload, err := readResponseBody(resp)
		if err != nil {
			return err
		}

		rep_cont, err := intef.ResponseUrl(resp)
		if err != nil {
			return err
		}
		if rep_cont != nil {
			payload = rep_cont
		}

		resp.Body = io.NopCloser(bytes.NewBuffer(payload))
		resp.ContentLength = int64(len(payload))
		resp.Header.Set("Content-Length", strconv.FormatInt(int64(len(payload)), 10))
		return nil
	}

	return &httputil.ReverseProxy{Director: director, Transport: transport, ModifyResponse: modifyFunc, ErrorHandler: intef.Error}
}

// 解析URL列表
//
//	urls	原始URL列表
func parseUrls(urls []string) []*url.URL {
	targets := []*url.URL{}
	for _, v := range urls {
		if u, err := url.Parse(v); err == nil {
			targets = append(targets, u)
		}
	}
	return targets
}

// 合并查询参数
//
//	query1	原始查询参数
//	query2	追加的查询参数
func mergeQuery(query1, query2 string) string {
	if query1 == "" {
		return query2
	}
	if query2 == "" {
		return query1
	}
	return query1 + "&" + query2
}

// 读取响应体内容
//
//	resp	响应体
func readResponseBody(resp *http.Response) ([]byte, error) {
	var payload []byte
	var readErr error
	if strings.Contains(resp.Header.Get("Content-Encoding"), "gzip") {
		gr, err := gzip.NewReader(resp.Body)
		if err != nil {
			return nil, err
		}
		payload, readErr = io.ReadAll(gr)
		resp.Header.Del("Content-Encoding")
	} else {
		payload, readErr = io.ReadAll(resp.Body)
	}
	return payload, readErr
}
